import copy
import os

from core.models.oracle_mdl import OracleModel
from core.components.logger import Logger
from core.utils.general_utils import AttrDict
from core.configs.default_data_configs.dmcontrol import data_spec
from core.components.evaluator import DummyEvaluator, ImageEvaluator, MultiImageEvaluator
from core.data.src.data_loaders import DMControlRescaleDataset


# Absolute directory containing this config file (symlinks in the file path are resolved first).
current_dir = os.path.split(os.path.realpath(__file__))[0]


# Experiment-level settings, exposed with attribute access via AttrDict.
configuration = {
    # The same model / logger classes are used for both training and test.
    'model': OracleModel,
    'model_test': OracleModel,
    'logger': Logger,
    'logger_test': Logger,
    'evaluator': DummyEvaluator,
    # Dataset root comes from the DATA_DIR environment variable (KeyError if unset).
    # Components are joined individually so the result has no stray './' segment.
    'data_dir': os.path.join(os.environ['DATA_DIR'], 'dmcontrol', 'walker', 'expert_L10'),
    'num_epochs': 100,
    'epoch_cycles_train': 10,
    'batch_size': 128,
    'discount_factor': 0.99,
    'n_frames': 3,
    'lr': 1e-4,
}
configuration = AttrDict(configuration)

# Model hyperparameters (plain dict).
model_config = dict(
    action_dim=6,
    state_dim=18,
    img_sz=64,
    nz_enc=256,
    nz_mid=256,
    # NOTE(review): 9 presumably equals n_frames (3) stacked 3-channel frames — confirm.
    input_nc=9,
    normalization='none',
    # Reuse the discount factor defined in the experiment configuration.
    discount=configuration.discount_factor,
)

# Dataset
data_config = AttrDict()
# Shallow-copy the imported default before customizing it: assigning attributes
# directly on `data_spec` would mutate the shared object owned by
# core.configs.default_data_configs.dmcontrol for every other importer.
data_config.dataset_spec = copy.copy(data_spec)
data_config.dataset_spec.delta_t = 1
data_config.dataset_spec.task_names = ['walk-expert', 'stand-expert', 'backward-expert']
